initial prototype for vits2 #2838

Closed · wants to merge 2 commits

Conversation

@p0p4k (Contributor) commented Aug 4, 2023

For testing, these are the notebooks at my other repo.
I still have some questions about the techniques used in the paper; they do not go into the specifics. It would be great if I could get some help.

@lexkoro thanks for the good discussions.

@p0p4k (Contributor, Author) commented Aug 4, 2023

#2828

@erogol (Member) left a comment

Looks like WIP. Pls, tag me again when you need a review.

@@ -430,3 +430,133 @@ def forward(self, x, x_mask):
x = self.norm_layers_2[i](x + y)
x = x * x_mask
return x

class ConditionalRelativePositionTransformer(nn.Module):
@erogol (Member):

Better if you move this into layers/vits2

return int((kernel_size * dilation - dilation) / 2)


class TextEncoder(nn.Module):
@erogol (Member):

I guess they use a regular transformer, not the relative-position one.
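
For context, a "regular" transformer text encoder built from plain torch.nn primitives (no relative-position attention) could look roughly like the sketch below. This is only an illustration of the distinction being pointed at, not code from this PR; the class name and hyperparameters are assumptions, and any positional encoding would have to be added to the embeddings separately.

import torch.nn as nn

# Illustrative sketch: a plain transformer encoder, in contrast to the
# relative-position attention blocks used by the original VITS text encoder.
class PlainTransformerTextEncoder(nn.Module):
    def __init__(self, n_vocab, hidden_channels=192, n_heads=2, n_layers=6):
        super().__init__()
        self.emb = nn.Embedding(n_vocab, hidden_channels)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden_channels, nhead=n_heads, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

    def forward(self, x, x_pad_mask):
        # x: [b, t] token ids; x_pad_mask: [b, t] bool, True at padded positions
        h = self.emb(x)
        return self.encoder(h, src_key_padding_mask=x_pad_mask)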

@p0p4k (Contributor, Author) commented Aug 5, 2023

Ah, yes. I forgot to turn it into a draft. I'll let you know as soon as it's done. Thanks!

@p0p4k p0p4k marked this pull request as draft August 5, 2023 11:25
@manmay-nakhashi (Collaborator) commented:

    def forward_mas(self, outputs, z_p, m_p, logs_p, x, x_mask, y_mask, g, lang_emb, noise_scale=0.01):
        # find the alignment path
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        with torch.no_grad():
            o_scale = torch.exp(-2 * logs_p)
            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1]).unsqueeze(-1)  # [b, t, 1]
            logp2 = torch.einsum("klm, kln -> kmn", [o_scale, -0.5 * (z_p**2)])
            logp3 = torch.einsum("klm, kln -> kmn", [m_p * o_scale, z_p])
            logp4 = torch.sum(-0.5 * (m_p**2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
            logp = logp2 + logp3 + logp1 + logp4
            
            # Adding Gaussian noise to the log probabilities
            epsilon = torch.randn_like(logp) * torch.std(logp) * noise_scale
            logp += epsilon
            
            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()  # [b, 1, t, t']

        # duration predictor
        attn_durations = attn.sum(3)
        if self.args.use_sdp:
            loss_duration = self.duration_predictor(
                x.detach() if self.args.detach_dp_input else x,
                x_mask,
                attn_durations,
                g=g.detach() if self.args.detach_dp_input and g is not None else g,
                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
            )
            loss_duration = loss_duration / torch.sum(x_mask)
        else:
            attn_log_durations = torch.log(attn_durations + 1e-6) * x_mask
            log_durations = self.duration_predictor(
                x.detach() if self.args.detach_dp_input else x,
                x_mask,
                g=g.detach() if self.args.detach_dp_input and g is not None else g,
                lang_emb=lang_emb.detach() if self.args.detach_dp_input and lang_emb is not None else lang_emb,
            )
            loss_duration = torch.sum((log_durations - attn_log_durations) ** 2, [1, 2]) / torch.sum(x_mask)
        outputs["loss_duration"] = loss_duration
        
        return outputs, attn

class VITS():
    self.noise_decay = 2e-6
    self.noise_scale = 0.01

We can update noise_scale -= noise_decay in on_train_step_start.
@p0p4k
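
A hedged sketch of how that suggested schedule could be wired up, assuming the model exposes the trainer's on_train_step_start callback as the comment implies; the class name Vits2Sketch is hypothetical, and the attribute defaults follow the comment above.

import torch.nn as nn

# Sketch only, not code from this PR: linearly anneal the Gaussian noise
# that forward_mas adds to logp, one step at a time.
class Vits2Sketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.noise_scale = 0.01  # initial scale of the noise added to logp
        self.noise_decay = 2e-6  # amount subtracted on every training step

    def on_train_step_start(self, trainer):
        # clamp at zero so the noise scale never goes negative late in training
        self.noise_scale = max(self.noise_scale - self.noise_decay, 0.0)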

@p0p4k p0p4k marked this pull request as ready for review September 6, 2023 06:17
@erogol (Member) commented Sep 7, 2023

@p0p4k If you don't mind, we need some testing like:

class TestVits(unittest.TestCase):
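
A minimal sketch of the kind of test being asked for, exercising only the pieces shown in this thread (the noise term in forward_mas and the proposed decay schedule); the test class and method names are hypothetical, not the PR's actual test suite.

import unittest
import torch

class TestVits2MASNoise(unittest.TestCase):
    def test_noise_keeps_logp_shape(self):
        # mirrors the epsilon term in the forward_mas snippet above
        logp = torch.randn(2, 17, 23)
        epsilon = torch.randn_like(logp) * torch.std(logp) * 0.01
        self.assertEqual((logp + epsilon).shape, logp.shape)

    def test_noise_scale_decays_linearly(self):
        # mirrors the proposed noise_scale -= noise_decay schedule
        noise_scale, noise_decay = 0.01, 2e-6
        for _ in range(1000):
            noise_scale = max(noise_scale - noise_decay, 0.0)
        self.assertAlmostEqual(noise_scale, 0.01 - 1000 * 2e-6)

if __name__ == "__main__":
    unittest.main()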

@erogol (Member) commented Sep 27, 2023

@p0p4k all looks good to me. Can you add a recipe for people to get started with training VITS2?
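
A VITS2 recipe would presumably mirror the existing LJSpeech VITS recipe in Coqui TTS. The sketch below is modeled on that recipe; Vits2Config and Vits2 (and their import paths and config fields) are placeholder assumptions for whatever this PR ends up exporting, not the merged API.

import os
from trainer import Trainer, TrainerArgs
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
# Hypothetical imports: the real module paths depend on how this PR lands.
from TTS.tts.configs.vits2_config import Vits2Config  # assumed
from TTS.tts.models.vits2 import Vits2  # assumed

output_path = os.path.dirname(os.path.abspath(__file__))

dataset_config = BaseDatasetConfig(
    formatter="ljspeech",
    meta_file_train="metadata.csv",
    path=os.path.join(output_path, "../LJSpeech-1.1/"),
)

config = Vits2Config(
    run_name="vits2_ljspeech",
    batch_size=32,
    eval_batch_size=16,
    epochs=1000,
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en-us",
    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
    output_path=output_path,
    datasets=[dataset_config],
)

# standard Coqui TTS bootstrapping, same as the existing VITS recipe
ap = AudioProcessor.init_from_config(config)
tokenizer, config = TTSTokenizer.init_from_config(config)
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)

model = Vits2(config, ap, tokenizer, speaker_manager=None)
trainer = Trainer(
    TrainerArgs(), config, output_path,
    model=model, train_samples=train_samples, eval_samples=eval_samples,
)
trainer.fit()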

@p0p4k (Contributor, Author) commented Sep 27, 2023

@erogol OK, I will do some cleanup, remove unused stuff, and add a recipe with tests soon. A little busy with other stuff! Thanks!

@erogol (Member) commented Sep 28, 2023

@p0p4k Sure, take your time...

stale bot commented Oct 28, 2023

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. You might also look at our discussion channels.

@stale stale bot added the wontfix This will not be worked on but feel free to help. label Oct 28, 2023
@stale stale bot closed this Nov 5, 2023
Labels
wontfix This will not be worked on but feel free to help.
3 participants